{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "279f0229",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Copyright 2021 NVIDIA Corporation. All Rights Reserved.\n",
    "#\n",
    "# Licensed under the Apache License, Version 2.0 (the \"License\");\n",
    "# you may not use this file except in compliance with the License.\n",
    "# You may obtain a copy of the License at\n",
    "#\n",
    "#     http://www.apache.org/licenses/LICENSE-2.0\n",
    "#\n",
    "# Unless required by applicable law or agreed to in writing, software\n",
    "# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
    "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
    "# See the License for the specific language governing permissions and\n",
    "# limitations under the License.\n",
    "# =============================================================================="
   ]
  },
  {
   "cell_type": "markdown",
   "id": "32e091de",
   "metadata": {},
   "source": [
    "<img src=\"http://developer.download.nvidia.com/compute/machine-learning/frameworks/nvidia_logo.png\" style=\"width: 90px; float: right;\">\n",
    "\n",
    "# Training HugeCTR Model with Pre-trained Embeddings\n",
    "\n",
    "In this notebook, we will train a deep neural network for predicting a user's rating (binary target with 1 for ratings `>3` and 0 for ratings `<=3`). The two categorical features are `userId` and `movieId`.\n",
    "\n",
    "We will also make use of movie's pretrained embeddings, extracted in the previous notebooks."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c1cf0042",
   "metadata": {},
   "source": [
    "## Loading pretrained movie features into non-trainable embedding layer"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "f4c3cb90",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the NVTabular movieId encoding produced by the Categorify op.\n",
    "# Each row maps a raw movieId to its categorical index (presumably the\n",
    "# row position, per NVTabular convention -- confirmed by the loop below\n",
    "# that uses row position as the embedding-matrix row).\n",
    "import pandas as pd\n",
    "import os\n",
    "\n",
    "INPUT_DATA_DIR = './data'\n",
    "movie_mapping = pd.read_parquet(os.path.join(INPUT_DATA_DIR, \"workflow-hugectr/categories/unique.movieId.parquet\"))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "dab4fd5d",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>movieId</th>\n",
       "      <th>movieId_size</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>56581</th>\n",
       "      <td>209155</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>56582</th>\n",
       "      <td>209157</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>56583</th>\n",
       "      <td>209159</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>56584</th>\n",
       "      <td>209169</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>56585</th>\n",
       "      <td>209171</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "       movieId  movieId_size\n",
       "56581   209155             1\n",
       "56582   209157             1\n",
       "56583   209159             1\n",
       "56584   209169             1\n",
       "56585   209171             1"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Inspect the last rows: `movieId` is the raw id; `movieId_size` is\n",
    "# presumably the occurrence count from Categorify -- confirm if relied on.\n",
    "movie_mapping.tail()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "76b723fa",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(62423, 3073)\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>movieId</th>\n",
       "      <th>poster_feature_0</th>\n",
       "      <th>poster_feature_1</th>\n",
       "      <th>poster_feature_2</th>\n",
       "      <th>poster_feature_3</th>\n",
       "      <th>poster_feature_4</th>\n",
       "      <th>poster_feature_5</th>\n",
       "      <th>poster_feature_6</th>\n",
       "      <th>poster_feature_7</th>\n",
       "      <th>poster_feature_8</th>\n",
       "      <th>...</th>\n",
       "      <th>text_feature_1014</th>\n",
       "      <th>text_feature_1015</th>\n",
       "      <th>text_feature_1016</th>\n",
       "      <th>text_feature_1017</th>\n",
       "      <th>text_feature_1018</th>\n",
       "      <th>text_feature_1019</th>\n",
       "      <th>text_feature_1020</th>\n",
       "      <th>text_feature_1021</th>\n",
       "      <th>text_feature_1022</th>\n",
       "      <th>text_feature_1023</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>1.0</td>\n",
       "      <td>0.026260</td>\n",
       "      <td>0.857608</td>\n",
       "      <td>0.410247</td>\n",
       "      <td>0.066654</td>\n",
       "      <td>0.382803</td>\n",
       "      <td>0.899998</td>\n",
       "      <td>0.511562</td>\n",
       "      <td>0.592291</td>\n",
       "      <td>0.565434</td>\n",
       "      <td>...</td>\n",
       "      <td>0.636716</td>\n",
       "      <td>0.578369</td>\n",
       "      <td>0.996169</td>\n",
       "      <td>0.402107</td>\n",
       "      <td>0.412318</td>\n",
       "      <td>0.859952</td>\n",
       "      <td>0.293852</td>\n",
       "      <td>0.341114</td>\n",
       "      <td>0.727113</td>\n",
       "      <td>0.085829</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>2.0</td>\n",
       "      <td>0.141265</td>\n",
       "      <td>0.721758</td>\n",
       "      <td>0.679958</td>\n",
       "      <td>0.955634</td>\n",
       "      <td>0.391091</td>\n",
       "      <td>0.324611</td>\n",
       "      <td>0.505211</td>\n",
       "      <td>0.258331</td>\n",
       "      <td>0.048264</td>\n",
       "      <td>...</td>\n",
       "      <td>0.161505</td>\n",
       "      <td>0.431864</td>\n",
       "      <td>0.836532</td>\n",
       "      <td>0.525013</td>\n",
       "      <td>0.654566</td>\n",
       "      <td>0.823841</td>\n",
       "      <td>0.818313</td>\n",
       "      <td>0.856280</td>\n",
       "      <td>0.638048</td>\n",
       "      <td>0.685537</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>3.0</td>\n",
       "      <td>0.119418</td>\n",
       "      <td>0.911146</td>\n",
       "      <td>0.470762</td>\n",
       "      <td>0.762258</td>\n",
       "      <td>0.626335</td>\n",
       "      <td>0.768947</td>\n",
       "      <td>0.241833</td>\n",
       "      <td>0.775992</td>\n",
       "      <td>0.236340</td>\n",
       "      <td>...</td>\n",
       "      <td>0.865548</td>\n",
       "      <td>0.387806</td>\n",
       "      <td>0.668321</td>\n",
       "      <td>0.552122</td>\n",
       "      <td>0.750238</td>\n",
       "      <td>0.863707</td>\n",
       "      <td>0.382173</td>\n",
       "      <td>0.894487</td>\n",
       "      <td>0.565142</td>\n",
       "      <td>0.164083</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>4.0</td>\n",
       "      <td>0.538184</td>\n",
       "      <td>0.980678</td>\n",
       "      <td>0.643513</td>\n",
       "      <td>0.928519</td>\n",
       "      <td>0.794906</td>\n",
       "      <td>0.201022</td>\n",
       "      <td>0.744666</td>\n",
       "      <td>0.962188</td>\n",
       "      <td>0.915320</td>\n",
       "      <td>...</td>\n",
       "      <td>0.777534</td>\n",
       "      <td>0.904200</td>\n",
       "      <td>0.167337</td>\n",
       "      <td>0.875194</td>\n",
       "      <td>0.180481</td>\n",
       "      <td>0.815904</td>\n",
       "      <td>0.808288</td>\n",
       "      <td>0.036711</td>\n",
       "      <td>0.902779</td>\n",
       "      <td>0.580946</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>5.0</td>\n",
       "      <td>0.772951</td>\n",
       "      <td>0.239788</td>\n",
       "      <td>0.061874</td>\n",
       "      <td>0.162997</td>\n",
       "      <td>0.388310</td>\n",
       "      <td>0.236311</td>\n",
       "      <td>0.162757</td>\n",
       "      <td>0.207134</td>\n",
       "      <td>0.111078</td>\n",
       "      <td>...</td>\n",
       "      <td>0.250022</td>\n",
       "      <td>0.335043</td>\n",
       "      <td>0.091674</td>\n",
       "      <td>0.121507</td>\n",
       "      <td>0.418124</td>\n",
       "      <td>0.150020</td>\n",
       "      <td>0.803506</td>\n",
       "      <td>0.059504</td>\n",
       "      <td>0.002342</td>\n",
       "      <td>0.932321</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "<p>5 rows × 3073 columns</p>\n",
       "</div>"
      ],
      "text/plain": [
       "   movieId  poster_feature_0  poster_feature_1  poster_feature_2  \\\n",
       "0      1.0          0.026260          0.857608          0.410247   \n",
       "1      2.0          0.141265          0.721758          0.679958   \n",
       "2      3.0          0.119418          0.911146          0.470762   \n",
       "3      4.0          0.538184          0.980678          0.643513   \n",
       "4      5.0          0.772951          0.239788          0.061874   \n",
       "\n",
       "   poster_feature_3  poster_feature_4  poster_feature_5  poster_feature_6  \\\n",
       "0          0.066654          0.382803          0.899998          0.511562   \n",
       "1          0.955634          0.391091          0.324611          0.505211   \n",
       "2          0.762258          0.626335          0.768947          0.241833   \n",
       "3          0.928519          0.794906          0.201022          0.744666   \n",
       "4          0.162997          0.388310          0.236311          0.162757   \n",
       "\n",
       "   poster_feature_7  poster_feature_8  ...  text_feature_1014  \\\n",
       "0          0.592291          0.565434  ...           0.636716   \n",
       "1          0.258331          0.048264  ...           0.161505   \n",
       "2          0.775992          0.236340  ...           0.865548   \n",
       "3          0.962188          0.915320  ...           0.777534   \n",
       "4          0.207134          0.111078  ...           0.250022   \n",
       "\n",
       "   text_feature_1015  text_feature_1016  text_feature_1017  text_feature_1018  \\\n",
       "0           0.578369           0.996169           0.402107           0.412318   \n",
       "1           0.431864           0.836532           0.525013           0.654566   \n",
       "2           0.387806           0.668321           0.552122           0.750238   \n",
       "3           0.904200           0.167337           0.875194           0.180481   \n",
       "4           0.335043           0.091674           0.121507           0.418124   \n",
       "\n",
       "   text_feature_1019  text_feature_1020  text_feature_1021  text_feature_1022  \\\n",
       "0           0.859952           0.293852           0.341114           0.727113   \n",
       "1           0.823841           0.818313           0.856280           0.638048   \n",
       "2           0.863707           0.382173           0.894487           0.565142   \n",
       "3           0.815904           0.808288           0.036711           0.902779   \n",
       "4           0.150020           0.803506           0.059504           0.002342   \n",
       "\n",
       "   text_feature_1023  \n",
       "0           0.085829  \n",
       "1           0.685537  \n",
       "2           0.164083  \n",
       "3           0.580946  \n",
       "4           0.932321  \n",
       "\n",
       "[5 rows x 3073 columns]"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Load the pretrained movie features: one row per movieId with 3073 columns\n",
    "# (movieId + 2048 poster/image features + 1024 text features).\n",
    "feature_df = pd.read_parquet('feature_df.parquet')\n",
    "print(feature_df.shape)\n",
    "feature_df.head()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "0cbdea3a",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Index the feature table by movieId so feature rows can be looked up by\n",
    "# raw id below. Re-assignment is used instead of inplace=True, following\n",
    "# pandas guidance to avoid the inplace API; the result is identical.\n",
    "feature_df = feature_df.set_index('movieId')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "6e17c0fb",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loading pretrained embedding matrix...\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|████████████████████████████████████| 56586/56586 [00:14<00:00, 3967.46it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Found features for 56585 movies (1 misses)\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "from tqdm import tqdm\n",
    "import numpy as np\n",
    "\n",
    "num_tokens = len(movie_mapping)\n",
    "embedding_dim = 2048+1024  # 2048 poster features + 1024 text features\n",
    "hits = 0\n",
    "misses = 0\n",
    "\n",
    "# Prepare embedding matrix: row i holds the pretrained features of the movie\n",
    "# with categorical index i; movies without features remain all-zero.\n",
    "embedding_matrix = np.zeros((num_tokens, embedding_dim))\n",
    "\n",
    "print(\"Loading pretrained embedding matrix...\")\n",
    "# Iterate positionally with enumerate rather than iterrows(): iterrows()\n",
    "# yields the DataFrame *index label*, which is only a valid row position\n",
    "# when the index happens to be a default RangeIndex.\n",
    "for i, movieId in enumerate(tqdm(movie_mapping[\"movieId\"].to_numpy(), total=num_tokens)):\n",
    "    if movieId in feature_df.index:\n",
    "        # embedding found\n",
    "        embedding_matrix[i] = feature_df.loc[movieId]\n",
    "        hits += 1\n",
    "    else:\n",
    "        misses += 1\n",
    "print(\"Found features for %d movies (%d misses)\" % (hits, misses))\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "f9b09cb4",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "3072"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Sanity check: 3072 = 2048 poster + 1024 text features per movie\n",
    "embedding_dim"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "844e668c",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([[0.        , 0.        , 0.        , ..., 0.        , 0.        ,\n",
       "        0.        ],\n",
       "       [0.17294852, 0.15285189, 0.26095702, ..., 0.75369112, 0.29602144,\n",
       "        0.78917433],\n",
       "       [0.13539355, 0.84843078, 0.70951219, ..., 0.10441725, 0.72871966,\n",
       "        0.11719463],\n",
       "       ...,\n",
       "       [0.18514273, 0.72422918, 0.04273015, ..., 0.1404219 , 0.54169348,\n",
       "        0.96875489],\n",
       "       [0.08307642, 0.3673532 , 0.15777258, ..., 0.01297393, 0.36267638,\n",
       "        0.14848055],\n",
       "       [0.82188376, 0.56516905, 0.70838085, ..., 0.45119769, 0.9273439 ,\n",
       "        0.42464321]])"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Peek at the matrix; an all-zero row is a movie with no pretrained\n",
    "# features (the single miss reported above).\n",
    "embedding_matrix"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "7e75526a",
   "metadata": {},
   "source": [
    "Next, we write the pretrained embedding to a raw format supported by HugeCTR.\n",
    "\n",
    "Note: As of version 3.2, HugeCTR only supports a maximum embedding size of 1024. Hence, we shall be using the first 512 elements of the image embedding plus the first 512 elements of the text embedding."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "e04070e0",
   "metadata": {},
   "outputs": [],
   "source": [
    "import struct\n",
    "\n",
    "PRETRAINED_EMBEDDING_SIZE = 1024\n",
    "POSTER_FEATURE_DIM = 2048  # text features start at this column offset\n",
    "\n",
    "def convert_pretrained_embeddings_to_sparse_model(keys, pre_trained_sparse_embeddings, hugectr_sparse_model, embedding_vec_size):\n",
    "    \"\"\"Write (key, vector) pairs in the raw binary layout HugeCTR expects.\n",
    "\n",
    "    keys: iterable of int embedding keys (categorical indices), aligned with rows.\n",
    "    pre_trained_sparse_embeddings: 2D array; row i is the feature row for keys[i].\n",
    "    hugectr_sparse_model: output directory; `key` and `emb_vector` files are written.\n",
    "    embedding_vec_size: number of float32 values written per key.\n",
    "    \"\"\"\n",
    "    os.makedirs(hugectr_sparse_model, exist_ok=True)  # portable; avoids a shell call\n",
    "    with open(\"{}/key\".format(hugectr_sparse_model), 'wb') as key_file, \\\n",
    "        open(\"{}/emb_vector\".format(hugectr_sparse_model), 'wb') as vec_file:\n",
    "\n",
    "        half = embedding_vec_size // 2\n",
    "        for i, key in enumerate(keys):\n",
    "            # First half from the image (poster) features, second half from the\n",
    "            # start of the text features. The text block begins at column 2048\n",
    "            # (2048 poster + 1024 text columns); the previous offset of 1024\n",
    "            # sliced poster features 1024:1536 instead of any text features.\n",
    "            vec = np.concatenate([pre_trained_sparse_embeddings[i, :half],\n",
    "                                  pre_trained_sparse_embeddings[i, POSTER_FEATURE_DIM:POSTER_FEATURE_DIM + half]])\n",
    "            key_struct = struct.pack('q', key)\n",
    "            vec_struct = struct.pack(str(embedding_vec_size) + \"f\", *vec)\n",
    "            key_file.write(key_struct)\n",
    "            vec_file.write(vec_struct)\n",
    "\n",
    "keys = list(movie_mapping.index)\n",
    "convert_pretrained_embeddings_to_sparse_model(keys, embedding_matrix, 'hugectr_pretrained_embedding.model', embedding_vec_size=PRETRAINED_EMBEDDING_SIZE) # HugeCTR not supporting embedding size > 1024"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "be386068",
   "metadata": {},
   "source": [
    "## Define and train model\n",
    "\n",
    "In this section, we define and train the model. The model comprises trainable embedding layers for the categorical features (`userId`, `movieId`) and a pretrained (non-trainable) embedding layer for the movie features.\n",
    "\n",
    "We will write the model to `./model.py` and execute it afterwards."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a0431c33",
   "metadata": {},
   "source": [
    "First, we need the cardinalities of each categorical feature to assign as `slot_size_array` in the model below."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "8e79396c",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'userId': (162542, 512), 'movieId': (56586, 512), 'movieId_duplicate': (56586, 512)}\n"
     ]
    }
   ],
   "source": [
    "import nvtabular as nvt\n",
    "from nvtabular.ops import get_embedding_sizes\n",
    "\n",
    "# Reload the fitted NVTabular workflow to recover each categorical feature's\n",
    "# cardinality; these feed slot_size_array in the HugeCTR model below.\n",
    "workflow = nvt.Workflow.load(os.path.join(INPUT_DATA_DIR, \"workflow-hugectr\"))\n",
    "\n",
    "embeddings = get_embedding_sizes(workflow)\n",
    "print(embeddings)\n",
    "\n",
    "# Expected: {'userId': (162542, 512), 'movieId': (56586, 512), 'movieId_duplicate': (56586, 512)}"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c8f9f3f4",
   "metadata": {},
   "source": [
    "Below we define the full model and write it to `./model.py`. At the end of training, `graph_to_json` converts the model graph into a JSON configuration, required for inference."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "a3b4d917",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Overwriting ./model.py\n"
     ]
    }
   ],
   "source": [
    "%%writefile './model.py'\n",
    "\n",
    "import hugectr\n",
    "from mpi4py import MPI  # noqa\n",
    "INPUT_DATA_DIR = './data/'\n",
    "\n",
    "# Solver: single GPU, int64 keys, full precision, iteration-based\n",
    "# (repeat_dataset) training mode.\n",
    "solver = hugectr.CreateSolver(\n",
    "    vvgpu=[[0]],\n",
    "    batchsize=2048,\n",
    "    batchsize_eval=2048,\n",
    "    max_eval_batches=160,\n",
    "    i64_input_key=True,\n",
    "    use_mixed_precision=False,\n",
    "    repeat_dataset=True,\n",
    ")\n",
    "optimizer = hugectr.CreateOptimizer(optimizer_type=hugectr.Optimizer_t.Adam)\n",
    "# Parquet reader. slot_size_array lists per-slot cardinalities:\n",
    "# [userId, movieId, <21-category feature -- presumably genres, confirm>, movieId_duplicate].\n",
    "reader = hugectr.DataReaderParams(\n",
    "    data_reader_type=hugectr.DataReaderType_t.Parquet,\n",
    "    source=[INPUT_DATA_DIR + \"train-hugectr/_file_list.txt\"],\n",
    "    eval_source=INPUT_DATA_DIR + \"valid-hugectr/_file_list.txt\",\n",
    "    check_type=hugectr.Check_t.Non,\n",
    "    slot_size_array=[162542, 56586, 21, 56586],\n",
    ")\n",
    "\n",
    "model = hugectr.Model(solver, reader, optimizer)\n",
    "\n",
    "# Two sparse input groups: 'data1' (3 slots, trainable embeddings) and\n",
    "# 'movieId' (1 slot), which feeds the frozen pretrained embedding.\n",
    "model.add(\n",
    "    hugectr.Input(\n",
    "        label_dim=1,\n",
    "        label_name=\"label\",\n",
    "        dense_dim=0,\n",
    "        dense_name=\"dense\",\n",
    "        data_reader_sparse_param_array=[\n",
    "            hugectr.DataReaderSparseParam(\"data1\", nnz_per_slot=[1, 1, 2], is_fixed_length=False, slot_num=3),\n",
    "            hugectr.DataReaderSparseParam(\"movieId\", nnz_per_slot=[1], is_fixed_length=True, slot_num=1)\n",
    "        ],\n",
    "    )\n",
    ")\n",
    "# Trainable 16-dim embeddings for the 3 categorical slots of data1.\n",
    "model.add(\n",
    "    hugectr.SparseEmbedding(\n",
    "        embedding_type=hugectr.Embedding_t.LocalizedSlotSparseEmbeddingHash,\n",
    "        workspace_size_per_gpu_in_mb=3000,\n",
    "        embedding_vec_size=16,\n",
    "        combiner=\"sum\",\n",
    "        sparse_embedding_name=\"sparse_embedding1\",\n",
    "        bottom_name=\"data1\",\n",
    "        optimizer=optimizer,\n",
    "    )\n",
    ")\n",
    "\n",
    "# pretrained embedding: 1024-dim movie features; the weights are loaded from\n",
    "# hugectr_pretrained_embedding.model and frozen before training (see below).\n",
    "model.add(\n",
    "    hugectr.SparseEmbedding(\n",
    "        embedding_type=hugectr.Embedding_t.DistributedSlotSparseEmbeddingHash,\n",
    "        workspace_size_per_gpu_in_mb=3000,\n",
    "        embedding_vec_size=1024,\n",
    "        combiner=\"sum\",\n",
    "        sparse_embedding_name=\"pretrained_embedding\",\n",
    "        bottom_name=\"movieId\",\n",
    "        optimizer=optimizer,\n",
    "    )\n",
    ")\n",
    "\n",
    "# Flatten (None, 3, 16) -> (None, 48): 3 slots x 16-dim vectors.\n",
    "model.add(hugectr.DenseLayer(layer_type = hugectr.Layer_t.Reshape,\n",
    "                            bottom_names = [\"sparse_embedding1\"],\n",
    "                            top_names = [\"reshape1\"],\n",
    "                            leading_dim=48))\n",
    "\n",
    "# Flatten (None, 1, 1024) -> (None, 1024).\n",
    "model.add(hugectr.DenseLayer(layer_type = hugectr.Layer_t.Reshape,\n",
    "                            bottom_names = [\"pretrained_embedding\"],\n",
    "                            top_names = [\"reshape2\"],\n",
    "                            leading_dim=1024))\n",
    "\n",
    "# Concatenate trainable and pretrained features: (None, 48 + 1024 = 1072).\n",
    "model.add(hugectr.DenseLayer(layer_type = hugectr.Layer_t.Concat,\n",
    "                            bottom_names = [\"reshape1\", \"reshape2\"],\n",
    "                            top_names = [\"concat1\"]))\n",
    "\n",
    "# MLP head: 128 -> 128 -> 1, ReLU activations in between.\n",
    "model.add(\n",
    "    hugectr.DenseLayer(\n",
    "        layer_type=hugectr.Layer_t.InnerProduct,\n",
    "        bottom_names=[\"concat1\"],\n",
    "        top_names=[\"fc1\"],\n",
    "        num_output=128,\n",
    "    )\n",
    ")\n",
    "model.add(\n",
    "    hugectr.DenseLayer(\n",
    "        layer_type=hugectr.Layer_t.ReLU,\n",
    "        bottom_names=[\"fc1\"],\n",
    "        top_names=[\"relu1\"],\n",
    "    )\n",
    ")\n",
    "model.add(\n",
    "    hugectr.DenseLayer(\n",
    "        layer_type=hugectr.Layer_t.InnerProduct,\n",
    "        bottom_names=[\"relu1\"],\n",
    "        top_names=[\"fc2\"],\n",
    "        num_output=128,\n",
    "    )\n",
    ")\n",
    "model.add(\n",
    "    hugectr.DenseLayer(\n",
    "        layer_type=hugectr.Layer_t.ReLU,\n",
    "        bottom_names=[\"fc2\"],\n",
    "        top_names=[\"relu2\"],\n",
    "    )\n",
    ")\n",
    "model.add(\n",
    "    hugectr.DenseLayer(\n",
    "        layer_type=hugectr.Layer_t.InnerProduct,\n",
    "        bottom_names=[\"relu2\"],\n",
    "        top_names=[\"fc3\"],\n",
    "        num_output=1,\n",
    "    )\n",
    ")\n",
    "# Binary cross-entropy against the binarized rating label.\n",
    "model.add(\n",
    "    hugectr.DenseLayer(\n",
    "        layer_type=hugectr.Layer_t.BinaryCrossEntropyLoss,\n",
    "        bottom_names=[\"fc3\", \"label\"],\n",
    "        top_names=[\"loss\"],\n",
    "    )\n",
    ")\n",
    "model.compile()\n",
    "model.summary()\n",
    "\n",
    "# Load the pretrained embedding layer and freeze it so it is not updated.\n",
    "model.load_sparse_weights({\"pretrained_embedding\": \"./hugectr_pretrained_embedding.model\"})\n",
    "model.freeze_embedding(\"pretrained_embedding\")\n",
    "\n",
    "# Train for 10001 iterations, evaluating every 200 and snapshotting at 5000.\n",
    "model.fit(max_iter=10001, display=100, eval_interval=200, snapshot=5000)\n",
    "# Export the model graph as a JSON configuration, needed for inference.\n",
    "model.graph_to_json(graph_config_file=\"hugectr-movielens.json\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b61a6196",
   "metadata": {},
   "source": [
    "We train our model."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "20e69a4d",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "HugeCTR Version: 3.2\n",
      "====================================================Model Init=====================================================\n",
      "[HUGECTR][01:09:00][INFO][RANK0]: Global seed is 476440390\n",
      "[HUGECTR][01:09:00][INFO][RANK0]: Device to NUMA mapping:\n",
      "  GPU 0 ->  node 0\n",
      "\n",
      "[HUGECTR][01:09:01][WARNING][RANK0]: Peer-to-peer access cannot be fully enabled.\n",
      "[HUGECTR][01:09:01][INFO][RANK0]: Start all2all warmup\n",
      "[HUGECTR][01:09:01][INFO][RANK0]: End all2all warmup\n",
      "[HUGECTR][01:09:01][INFO][RANK0]: Using All-reduce algorithm: NCCL\n",
      "[HUGECTR][01:09:01][INFO][RANK0]: Device 0: Tesla V100-SXM2-16GB\n",
      "[HUGECTR][01:09:01][INFO][RANK0]: num of DataReader workers: 1\n",
      "[HUGECTR][01:09:01][INFO][RANK0]: Vocabulary size: 275735\n",
      "[HUGECTR][01:09:01][INFO][RANK0]: max_vocabulary_size_per_gpu_=16384000\n",
      "[HUGECTR][01:09:01][INFO][RANK0]: max_vocabulary_size_per_gpu_=256000\n",
      "[HUGECTR][01:09:01][INFO][RANK0]: Graph analysis to resolve tensor dependency\n",
      "===================================================Model Compile===================================================\n",
      "[HUGECTR][01:09:04][INFO][RANK0]: gpu0 start to init embedding\n",
      "[HUGECTR][01:09:04][INFO][RANK0]: gpu0 init embedding done\n",
      "[HUGECTR][01:09:04][INFO][RANK0]: gpu0 start to init embedding\n",
      "[HUGECTR][01:09:04][INFO][RANK0]: gpu0 init embedding done\n",
      "[HUGECTR][01:09:04][INFO][RANK0]: Starting AUC NCCL warm-up\n",
      "[HUGECTR][01:09:04][INFO][RANK0]: Warm-up done\n",
      "===================================================Model Summary===================================================\n",
      "label                                   Dense                         Sparse                        \n",
      "label                                   dense                          data1,movieId                 \n",
      "(None, 1)                               (None, 0)                               \n",
      "——————————————————————————————————————————————————————————————————————————————————————————————————————————————————\n",
      "Layer Type                              Input Name                    Output Name                   Output Shape                  \n",
      "——————————————————————————————————————————————————————————————————————————————————————————————————————————————————\n",
      "LocalizedSlotSparseEmbeddingHash        data1                         sparse_embedding1             (None, 3, 16)                 \n",
      "------------------------------------------------------------------------------------------------------------------\n",
      "DistributedSlotSparseEmbeddingHash      movieId                       pretrained_embedding          (None, 1, 1024)               \n",
      "------------------------------------------------------------------------------------------------------------------\n",
      "Reshape                                 sparse_embedding1             reshape1                      (None, 48)                    \n",
      "------------------------------------------------------------------------------------------------------------------\n",
      "Reshape                                 pretrained_embedding          reshape2                      (None, 1024)                  \n",
      "------------------------------------------------------------------------------------------------------------------\n",
      "Concat                                  reshape1                      concat1                       (None, 1072)                  \n",
      "                                        reshape2                                                                                  \n",
      "------------------------------------------------------------------------------------------------------------------\n",
      "InnerProduct                            concat1                       fc1                           (None, 128)                   \n",
      "------------------------------------------------------------------------------------------------------------------\n",
      "ReLU                                    fc1                           relu1                         (None, 128)                   \n",
      "------------------------------------------------------------------------------------------------------------------\n",
      "InnerProduct                            relu1                         fc2                           (None, 128)                   \n",
      "------------------------------------------------------------------------------------------------------------------\n",
      "ReLU                                    fc2                           relu2                         (None, 128)                   \n",
      "------------------------------------------------------------------------------------------------------------------\n",
      "InnerProduct                            relu2                         fc3                           (None, 1)                     \n",
      "------------------------------------------------------------------------------------------------------------------\n",
      "BinaryCrossEntropyLoss                  fc3                           loss                                                        \n",
      "                                        label                                                                                     \n",
      "------------------------------------------------------------------------------------------------------------------\n",
      "[HUGECTR][01:09:04][INFO][RANK0]: Loading sparse model: ./hugectr_pretrained_embedding.model\n",
      "=====================================================Model Fit=====================================================\n",
      "[HUGECTR][01:09:06][INFO][RANK0]: Use non-epoch mode with number of iterations: 10001\n",
      "[HUGECTR][01:09:06][INFO][RANK0]: Training batchsize: 2048, evaluation batchsize: 2048\n",
      "[HUGECTR][01:09:06][INFO][RANK0]: Evaluation interval: 200, snapshot interval: 5000\n",
      "[HUGECTR][01:09:06][INFO][RANK0]: Dense network trainable: True\n",
      "[HUGECTR][01:09:06][INFO][RANK0]: Sparse embedding pretrained_embedding trainable: False\n",
      "[HUGECTR][01:09:06][INFO][RANK0]: Sparse embedding sparse_embedding1 trainable: True\n",
      "[HUGECTR][01:09:06][INFO][RANK0]: Use mixed precision: False, scaler: 1.000000, use cuda graph: True\n",
      "[HUGECTR][01:09:06][INFO][RANK0]: lr: 0.001000, warmup_steps: 1, end_lr: 0.000000\n",
      "[HUGECTR][01:09:06][INFO][RANK0]: decay_start: 0, decay_steps: 1, decay_power: 2.000000\n",
      "[HUGECTR][01:09:06][INFO][RANK0]: Training source file: ./data/train-hugectr/_file_list.txt\n",
      "[HUGECTR][01:09:06][INFO][RANK0]: Evaluation source file: ./data/valid-hugectr/_file_list.txt\n",
      "[HUGECTR][01:09:08][INFO][RANK0]: Iter: 100 Time(100 iters): 2.297110s Loss: 0.581705 lr:0.001000\n",
      "[HUGECTR][01:09:11][INFO][RANK0]: Iter: 200 Time(100 iters): 2.274680s Loss: 0.574425 lr:0.001000\n",
      "[HUGECTR][01:09:11][INFO][RANK0]: Evaluation, AUC: 0.746443\n",
      "[HUGECTR][01:09:11][INFO][RANK0]: Eval Time for 160 iters: 0.054157s\n",
      "[HUGECTR][01:09:13][INFO][RANK0]: Iter: 300 Time(100 iters): 2.332273s Loss: 0.564224 lr:0.001000\n",
      "[HUGECTR][01:09:15][INFO][RANK0]: Iter: 400 Time(100 iters): 2.277900s Loss: 0.550730 lr:0.001000\n",
      "[HUGECTR][01:09:15][INFO][RANK0]: Evaluation, AUC: 0.764630\n",
      "[HUGECTR][01:09:15][INFO][RANK0]: Eval Time for 160 iters: 0.054009s\n",
      "[HUGECTR][01:09:18][INFO][RANK0]: Iter: 500 Time(100 iters): 2.434429s Loss: 0.536507 lr:0.001000\n",
      "[HUGECTR][01:09:20][INFO][RANK0]: Iter: 600 Time(100 iters): 2.279014s Loss: 0.525059 lr:0.001000\n",
      "[HUGECTR][01:09:20][INFO][RANK0]: Evaluation, AUC: 0.773702\n",
      "[HUGECTR][01:09:20][INFO][RANK0]: Eval Time for 160 iters: 0.054287s\n",
      "[HUGECTR][01:09:22][INFO][RANK0]: Iter: 700 Time(100 iters): 2.335757s Loss: 0.532503 lr:0.001000\n",
      "[HUGECTR][01:09:25][INFO][RANK0]: Iter: 800 Time(100 iters): 2.278661s Loss: 0.526352 lr:0.001000\n",
      "[HUGECTR][01:09:25][INFO][RANK0]: Evaluation, AUC: 0.779897\n",
      "[HUGECTR][01:09:25][INFO][RANK0]: Eval Time for 160 iters: 0.167787s\n",
      "[HUGECTR][01:09:27][INFO][RANK0]: Iter: 900 Time(100 iters): 2.447136s Loss: 0.547141 lr:0.001000\n",
      "[HUGECTR][01:09:29][INFO][RANK0]: Iter: 1000 Time(100 iters): 2.376035s Loss: 0.548916 lr:0.001000\n",
      "[HUGECTR][01:09:30][INFO][RANK0]: Evaluation, AUC: 0.784775\n",
      "[HUGECTR][01:09:30][INFO][RANK0]: Eval Time for 160 iters: 0.054224s\n",
      "[HUGECTR][01:09:32][INFO][RANK0]: Iter: 1100 Time(100 iters): 2.334735s Loss: 0.540766 lr:0.001000\n",
      "[HUGECTR][01:09:34][INFO][RANK0]: Iter: 1200 Time(100 iters): 2.277728s Loss: 0.515882 lr:0.001000\n",
      "[HUGECTR][01:09:34][INFO][RANK0]: Evaluation, AUC: 0.786808\n",
      "[HUGECTR][01:09:34][INFO][RANK0]: Eval Time for 160 iters: 0.054551s\n",
      "[HUGECTR][01:09:36][INFO][RANK0]: Iter: 1300 Time(100 iters): 2.336372s Loss: 0.531510 lr:0.001000\n",
      "[HUGECTR][01:09:39][INFO][RANK0]: Iter: 1400 Time(100 iters): 2.277408s Loss: 0.511901 lr:0.001000\n",
      "[HUGECTR][01:09:39][INFO][RANK0]: Evaluation, AUC: 0.791416\n",
      "[HUGECTR][01:09:39][INFO][RANK0]: Eval Time for 160 iters: 0.165986s\n",
      "[HUGECTR][01:09:41][INFO][RANK0]: Iter: 1500 Time(100 iters): 2.554217s Loss: 0.522047 lr:0.001000\n",
      "[HUGECTR][01:09:44][INFO][RANK0]: Iter: 1600 Time(100 iters): 2.279548s Loss: 0.540521 lr:0.001000\n",
      "[HUGECTR][01:09:44][INFO][RANK0]: Evaluation, AUC: 0.793460\n",
      "[HUGECTR][01:09:44][INFO][RANK0]: Eval Time for 160 iters: 0.054801s\n",
      "[HUGECTR][01:09:46][INFO][RANK0]: Iter: 1700 Time(100 iters): 2.336303s Loss: 0.525447 lr:0.001000\n",
      "[HUGECTR][01:09:48][INFO][RANK0]: Iter: 1800 Time(100 iters): 2.278906s Loss: 0.523558 lr:0.001000\n",
      "[HUGECTR][01:09:48][INFO][RANK0]: Evaluation, AUC: 0.793137\n",
      "[HUGECTR][01:09:48][INFO][RANK0]: Eval Time for 160 iters: 0.054431s\n",
      "[HUGECTR][01:09:51][INFO][RANK0]: Iter: 1900 Time(100 iters): 2.336023s Loss: 0.511348 lr:0.001000\n",
      "[HUGECTR][01:09:53][INFO][RANK0]: Iter: 2000 Time(100 iters): 2.384979s Loss: 0.515268 lr:0.001000\n",
      "[HUGECTR][01:09:53][INFO][RANK0]: Evaluation, AUC: 0.796599\n",
      "[HUGECTR][01:09:53][INFO][RANK0]: Eval Time for 160 iters: 0.172160s\n",
      "[HUGECTR][01:09:55][INFO][RANK0]: Iter: 2100 Time(100 iters): 2.453174s Loss: 0.526615 lr:0.001000\n",
      "[HUGECTR][01:09:58][INFO][RANK0]: Iter: 2200 Time(100 iters): 2.278781s Loss: 0.536789 lr:0.001000\n",
      "[HUGECTR][01:09:58][INFO][RANK0]: Evaluation, AUC: 0.798459\n",
      "[HUGECTR][01:09:58][INFO][RANK0]: Eval Time for 160 iters: 0.054509s\n",
      "[HUGECTR][01:10:00][INFO][RANK0]: Iter: 2300 Time(100 iters): 2.335596s Loss: 0.508902 lr:0.001000\n",
      "[HUGECTR][01:10:02][INFO][RANK0]: Iter: 2400 Time(100 iters): 2.277901s Loss: 0.520411 lr:0.001000\n",
      "[HUGECTR][01:10:02][INFO][RANK0]: Evaluation, AUC: 0.798726\n",
      "[HUGECTR][01:10:02][INFO][RANK0]: Eval Time for 160 iters: 0.054518s\n",
      "[HUGECTR][01:10:05][INFO][RANK0]: Iter: 2500 Time(100 iters): 2.444557s Loss: 0.490832 lr:0.001000\n",
      "[HUGECTR][01:10:07][INFO][RANK0]: Iter: 2600 Time(100 iters): 2.279310s Loss: 0.507799 lr:0.001000\n",
      "[HUGECTR][01:10:07][INFO][RANK0]: Evaluation, AUC: 0.801325\n",
      "[HUGECTR][01:10:07][INFO][RANK0]: Eval Time for 160 iters: 0.164203s\n",
      "[HUGECTR][01:10:10][INFO][RANK0]: Iter: 2700 Time(100 iters): 2.443310s Loss: 0.519460 lr:0.001000\n",
      "[HUGECTR][01:10:12][INFO][RANK0]: Iter: 2800 Time(100 iters): 2.277569s Loss: 0.512426 lr:0.001000\n",
      "[HUGECTR][01:10:12][INFO][RANK0]: Evaluation, AUC: 0.800731\n",
      "[HUGECTR][01:10:12][INFO][RANK0]: Eval Time for 160 iters: 0.054590s\n",
      "[HUGECTR][01:10:14][INFO][RANK0]: Iter: 2900 Time(100 iters): 2.336213s Loss: 0.512216 lr:0.001000\n",
      "[HUGECTR][01:10:17][INFO][RANK0]: Iter: 3000 Time(100 iters): 2.384833s Loss: 0.522102 lr:0.001000\n",
      "[HUGECTR][01:10:17][INFO][RANK0]: Evaluation, AUC: 0.803801\n",
      "[HUGECTR][01:10:17][INFO][RANK0]: Eval Time for 160 iters: 0.054133s\n",
      "[HUGECTR][01:10:19][INFO][RANK0]: Iter: 3100 Time(100 iters): 2.334245s Loss: 0.507463 lr:0.001000\n",
      "[HUGECTR][01:10:21][INFO][RANK0]: Iter: 3200 Time(100 iters): 2.279046s Loss: 0.526148 lr:0.001000\n",
      "[HUGECTR][01:10:21][INFO][RANK0]: Evaluation, AUC: 0.802950\n",
      "[HUGECTR][01:10:21][INFO][RANK0]: Eval Time for 160 iters: 0.070003s\n",
      "[HUGECTR][01:10:24][INFO][RANK0]: Iter: 3300 Time(100 iters): 2.352114s Loss: 0.504611 lr:0.001000\n",
      "[HUGECTR][01:10:26][INFO][RANK0]: Iter: 3400 Time(100 iters): 2.277292s Loss: 0.502907 lr:0.001000\n",
      "[HUGECTR][01:10:26][INFO][RANK0]: Evaluation, AUC: 0.804364\n",
      "[HUGECTR][01:10:26][INFO][RANK0]: Eval Time for 160 iters: 0.054315s\n",
      "[HUGECTR][01:10:28][INFO][RANK0]: Iter: 3500 Time(100 iters): 2.442956s Loss: 0.512927 lr:0.001000\n",
      "[HUGECTR][01:10:31][INFO][RANK0]: Iter: 3600 Time(100 iters): 2.277974s Loss: 0.519042 lr:0.001000\n",
      "[HUGECTR][01:10:31][INFO][RANK0]: Evaluation, AUC: 0.806404\n",
      "[HUGECTR][01:10:31][INFO][RANK0]: Eval Time for 160 iters: 0.054291s\n",
      "[HUGECTR][01:10:33][INFO][RANK0]: Iter: 3700 Time(100 iters): 2.335365s Loss: 0.499368 lr:0.001000\n",
      "[HUGECTR][01:10:35][INFO][RANK0]: Iter: 3800 Time(100 iters): 2.277786s Loss: 0.509683 lr:0.001000\n",
      "[HUGECTR][01:10:35][INFO][RANK0]: Evaluation, AUC: 0.805164\n",
      "[HUGECTR][01:10:35][INFO][RANK0]: Eval Time for 160 iters: 0.064908s\n",
      "[HUGECTR][01:10:38][INFO][RANK0]: Iter: 3900 Time(100 iters): 2.344106s Loss: 0.508182 lr:0.001000\n",
      "[HUGECTR][01:10:40][INFO][RANK0]: Iter: 4000 Time(100 iters): 2.387872s Loss: 0.493841 lr:0.001000\n",
      "[HUGECTR][01:10:40][INFO][RANK0]: Evaluation, AUC: 0.808367\n",
      "[HUGECTR][01:10:40][INFO][RANK0]: Eval Time for 160 iters: 0.054222s\n",
      "[HUGECTR][01:10:42][INFO][RANK0]: Iter: 4100 Time(100 iters): 2.335361s Loss: 0.508106 lr:0.001000\n",
      "[HUGECTR][01:10:45][INFO][RANK0]: Iter: 4200 Time(100 iters): 2.278802s Loss: 0.519000 lr:0.001000\n",
      "[HUGECTR][01:10:45][INFO][RANK0]: Evaluation, AUC: 0.808897\n",
      "[HUGECTR][01:10:45][INFO][RANK0]: Eval Time for 160 iters: 0.054320s\n",
      "[HUGECTR][01:10:47][INFO][RANK0]: Iter: 4300 Time(100 iters): 2.334094s Loss: 0.502797 lr:0.001000\n",
      "[HUGECTR][01:10:49][INFO][RANK0]: Iter: 4400 Time(100 iters): 2.388990s Loss: 0.508890 lr:0.001000\n",
      "[HUGECTR][01:10:49][INFO][RANK0]: Evaluation, AUC: 0.809649\n",
      "[HUGECTR][01:10:49][INFO][RANK0]: Eval Time for 160 iters: 0.074584s\n",
      "[HUGECTR][01:10:52][INFO][RANK0]: Iter: 4500 Time(100 iters): 2.355005s Loss: 0.505778 lr:0.001000\n",
      "[HUGECTR][01:10:54][INFO][RANK0]: Iter: 4600 Time(100 iters): 2.277275s Loss: 0.532776 lr:0.001000\n",
      "[HUGECTR][01:10:54][INFO][RANK0]: Evaluation, AUC: 0.810962\n",
      "[HUGECTR][01:10:54][INFO][RANK0]: Eval Time for 160 iters: 0.054498s\n",
      "[HUGECTR][01:10:56][INFO][RANK0]: Iter: 4700 Time(100 iters): 2.335553s Loss: 0.503001 lr:0.001000\n",
      "[HUGECTR][01:10:59][INFO][RANK0]: Iter: 4800 Time(100 iters): 2.279237s Loss: 0.495762 lr:0.001000\n",
      "[HUGECTR][01:10:59][INFO][RANK0]: Evaluation, AUC: 0.808618\n",
      "[HUGECTR][01:10:59][INFO][RANK0]: Eval Time for 160 iters: 0.054287s\n",
      "[HUGECTR][01:11:01][INFO][RANK0]: Iter: 4900 Time(100 iters): 2.449926s Loss: 0.503213 lr:0.001000\n",
      "[HUGECTR][01:11:03][INFO][RANK0]: Iter: 5000 Time(100 iters): 2.277141s Loss: 0.481138 lr:0.001000\n",
      "[HUGECTR][01:11:03][INFO][RANK0]: Evaluation, AUC: 0.810767\n",
      "[HUGECTR][01:11:03][INFO][RANK0]: Eval Time for 160 iters: 0.064807s\n",
      "[HUGECTR][01:11:04][INFO][RANK0]: Rank0: Dump hash table from GPU0\n",
      "[HUGECTR][01:11:04][INFO][RANK0]: Rank0: Write hash table <key,value> pairs to file\n",
      "[HUGECTR][01:11:04][INFO][RANK0]: Done\n",
      "[HUGECTR][01:11:04][INFO][RANK0]: Rank0: Write hash table to file\n",
      "[HUGECTR][01:11:13][INFO][RANK0]: Dumping sparse weights to files, successful\n",
      "[HUGECTR][01:11:13][INFO][RANK0]: Rank0: Write optimzer state to file\n",
      "[HUGECTR][01:11:14][INFO][RANK0]: Done\n",
      "[HUGECTR][01:11:14][INFO][RANK0]: Rank0: Write optimzer state to file\n",
      "[HUGECTR][01:11:15][INFO][RANK0]: Done\n",
      "[HUGECTR][01:11:34][INFO][RANK0]: Rank0: Write optimzer state to file\n",
      "[HUGECTR][01:11:35][INFO][RANK0]: Done\n",
      "[HUGECTR][01:11:35][INFO][RANK0]: Rank0: Write optimzer state to file\n",
      "[HUGECTR][01:11:36][INFO][RANK0]: Done\n",
      "[HUGECTR][01:11:55][INFO][RANK0]: Dumping sparse optimzer states to files, successful\n",
      "[HUGECTR][01:11:55][INFO][RANK0]: Dumping dense weights to file, successful\n",
      "[HUGECTR][01:11:55][INFO][RANK0]: Dumping dense optimizer states to file, successful\n",
      "[HUGECTR][01:11:55][INFO][RANK0]: Dumping untrainable weights to file, successful\n",
      "[HUGECTR][01:11:57][INFO][RANK0]: Iter: 5100 Time(100 iters): 53.630313s Loss: 0.485568 lr:0.001000\n",
      "[HUGECTR][01:11:59][INFO][RANK0]: Iter: 5200 Time(100 iters): 2.278359s Loss: 0.518924 lr:0.001000\n",
      "[HUGECTR][01:11:59][INFO][RANK0]: Evaluation, AUC: 0.811217\n",
      "[HUGECTR][01:11:59][INFO][RANK0]: Eval Time for 160 iters: 0.054624s\n",
      "[HUGECTR][01:12:02][INFO][RANK0]: Iter: 5300 Time(100 iters): 2.336246s Loss: 0.516505 lr:0.001000\n",
      "[HUGECTR][01:12:04][INFO][RANK0]: Iter: 5400 Time(100 iters): 2.384571s Loss: 0.512404 lr:0.001000\n",
      "[HUGECTR][01:12:04][INFO][RANK0]: Evaluation, AUC: 0.811464\n",
      "[HUGECTR][01:12:04][INFO][RANK0]: Eval Time for 160 iters: 0.054350s\n",
      "[HUGECTR][01:12:06][INFO][RANK0]: Iter: 5500 Time(100 iters): 2.334675s Loss: 0.500305 lr:0.001000\n",
      "[HUGECTR][01:12:09][INFO][RANK0]: Iter: 5600 Time(100 iters): 2.279563s Loss: 0.484969 lr:0.001000\n"
     ]
    }
   ],
   "source": [
    "!python model.py"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
